#include "Game Engine\Header\Terrain.h"
#include "Game Engine\Header\TextureMgr.h"
#include "Game Engine\Header\Camera.h"
#include "Source\Header\Effects.h"
#include "Source\Header\InputLayouts.h"
#include <fstream>
#include <sstream>

// When defined, Terrain::InitPhysics() builds the terrain BVH from the mesh and
// writes it to "myShape.bullet". Remove this line to make InitPhysics() load the
// BVH from that file instead (run once with it defined to create the file).
#define SERIALIZE_TO_DISK 1

Terrain::Terrain()
	: md3dDevice(0),
	  mVB(0),
	  mIB(0),
	  mLayer0(0),
	  mLayer1(0),
	  mLayer2(0),
	  mLayer3(0),
	  mLayer4(0),
	  mBlendMap(0),
	  mVertices(),
	  mIndices(),
	  mDefaultContactProcessingThreshold( BT_LARGE_FLOAT )
{
	// Device, buffers and textures are acquired later by init(); until then all
	// handles are null and the CPU-side vertex/index arrays are empty.
}

Terrain::~Terrain()
{
	// Release the GPU buffers created by buildIB() / buildVB().
	// NOTE(review): the Bullet objects created in InitPhysics() (index/vertex
	// array, trimesh shape, rigid body, gVertices/gIndices) are not freed here.
	ReleaseCOM(mIB);
	ReleaseCOM(mVB);
}

float Terrain::width()const
{
	// Extent along x: NumCols samples span NumCols-1 cells of CellSpacing each.
	return mInfo.CellSpacing * (mInfo.NumCols - 1);
}

float Terrain::depth()const
{
	// Extent along z: NumRows samples span NumRows-1 cells of CellSpacing each.
	return mInfo.CellSpacing * (mInfo.NumRows - 1);
}

float Terrain::getHeight(float x, float z)const
{
	// Returns the terrain height at local-space (x, z) by interpolating over the
	// triangle of the grid cell that contains the point. Points outside the
	// terrain are clamped to the border and receive the edge height.

	// A grid needs at least 2x2 samples to contain a cell; also refuse to read a
	// heightmap that has not been loaded yet.
	if( mInfo.NumRows < 2 || mInfo.NumCols < 2 ||
		mHeightmap.size() < (size_t)mInfo.NumRows * mInfo.NumCols )
		return 0.0f;

	// Transform from terrain local space to "cell" space.
	float c = (x + 0.5f*width()) /  mInfo.CellSpacing;
	float d = (z - 0.5f*depth()) / -mInfo.CellSpacing;

	// Clamp into the grid so B/C/D below never index past mHeightmap.
	const float maxC = (float)(mInfo.NumCols - 1);
	const float maxD = (float)(mInfo.NumRows - 1);
	if( c < 0.0f ) c = 0.0f;
	if( c > maxC ) c = maxC;
	if( d < 0.0f ) d = 0.0f;
	if( d > maxD ) d = maxD;

	// Get the row and column we are in. A point exactly on the last sample
	// belongs to the last cell, which starts one sample before the edge.
	int row = (int)floorf(d);
	int col = (int)floorf(c);
	if( row > (int)mInfo.NumRows - 2 ) row = (int)mInfo.NumRows - 2;
	if( col > (int)mInfo.NumCols - 2 ) col = (int)mInfo.NumCols - 2;

	// Grab the heights of the cell we are in.
	// A*--*B
	//  | /|
	//  |/ |
	// C*--*D
	float A = mHeightmap[row*mInfo.NumCols + col];
	float B = mHeightmap[row*mInfo.NumCols + col + 1];
	float C = mHeightmap[(row+1)*mInfo.NumCols + col];
	float D = mHeightmap[(row+1)*mInfo.NumCols + col + 1];

	// Where we are relative to the cell (each in [0,1]).
	float s = c - (float)col;
	float t = d - (float)row;

	// If upper triangle ABC.
	if( s + t <= 1.0f)
	{
		float uy = B - A;
		float vy = C - A;
		return A + s*uy + t*vy;
	}
	else // lower triangle DCB.
	{
		float uy = C - D;
		float vy = B - D;
		return D + (1.0f-s)*uy + (1.0f-t)*vy;
	}
}

void Terrain::init(ID3D10Device* device, btDynamicsWorld* world, const InitInfo& initInfo)
{
	// Remember the device, physics world and terrain description for later calls.
	md3dDevice     = device;
	mDynamicsWorld = world;
	mInfo          = initInfo;

	// Cache the technique and effect variables that draw()/setDirectionToSun() write.
	mTech          = fx::TerrainFX->GetTechniqueByName("TerrainTech");
	mfxWVPVar      = fx::TerrainFX->GetVariableByName("gWVP")->AsMatrix();
	mfxWorldVar    = fx::TerrainFX->GetVariableByName("gWorld")->AsMatrix();
	mfxDirToSunVar = fx::TerrainFX->GetVariableByName("gDirToSunW")->AsVector();
	mfxLayer0Var   = fx::TerrainFX->GetVariableByName("gLayer0")->AsShaderResource();
	mfxLayer1Var   = fx::TerrainFX->GetVariableByName("gLayer1")->AsShaderResource();
	mfxLayer2Var   = fx::TerrainFX->GetVariableByName("gLayer2")->AsShaderResource();
	mfxLayer3Var   = fx::TerrainFX->GetVariableByName("gLayer3")->AsShaderResource();
	mfxLayer4Var   = fx::TerrainFX->GetVariableByName("gLayer4")->AsShaderResource();
	mfxBlendMapVar = fx::TerrainFX->GetVariableByName("gBlendMap")->AsShaderResource();

	// Grid size: one vertex per heightmap sample, two triangles per cell.
	mNumVertices = mInfo.NumRows*mInfo.NumCols;
	mNumFaces    = (mInfo.NumRows-1)*(mInfo.NumCols-1)*2;

	// Heights -> smoothed heights -> GPU geometry.
	loadHeightmap();
	smooth();
	buildVB();
	buildIB();

	// Layer textures and the blend map that mixes them.
	mLayer0   = GetTextureMgr().createTex(initInfo.LayerMapFilename0);
	mLayer1   = GetTextureMgr().createTex(initInfo.LayerMapFilename1);
	mLayer2   = GetTextureMgr().createTex(initInfo.LayerMapFilename2);
	mLayer3   = GetTextureMgr().createTex(initInfo.LayerMapFilename3);
	mLayer4   = GetTextureMgr().createTex(initInfo.LayerMapFilename4);
	mBlendMap = GetTextureMgr().createTex(initInfo.BlendMapFilename);

	D3DXMatrixIdentity( &mWorld );

	// Collision mesh built from the vertices/indices generated above.
	InitPhysics();
}

// File-scope Bullet state written by Terrain::InitPhysics(). The index/vertex
// arrays are handed to a btTriangleIndexVertexArray, so they must stay alive
// while the collision shape is in use.
static btVector3*              gVertices    = 0;
static int*                    gIndices     = 0;
static btBvhTriangleMeshShape* trimeshShape = 0;
static btRigidBody*            staticBody   = 0;
static float                   waveheight   = 50.f;

// Grid spacing used by Terrain::setVertexPositions().
static const float TRIANGLE_SIZE = 8.f;



/// Friction combiner used by CustomMaterialCombinerCallback (installed as
/// gContactAddedCallback; objects opt in via CF_CUSTOM_MATERIAL_CALLBACK).
/// Multiplies the two coefficients and clamps the result to [-10, 10].
inline btScalar	calculateCombinedFriction(float friction0,float friction1)
{
	const btScalar MAX_FRICTION  = 10.f;
	const btScalar product = friction0 * friction1;

	if (product > MAX_FRICTION)
		return MAX_FRICTION;
	if (product < -MAX_FRICTION)
		return -MAX_FRICTION;
	return product;
}

/// Restitution combiner: the product of both coefficients, so a contact only
/// bounces when both surfaces have non-zero restitution.
inline btScalar	calculateCombinedRestitution(float restitution0,float restitution1)
{
	const btScalar combined = restitution0 * restitution1;
	return combined;
}



/// Contact-added callback: overrides the material of objects flagged
/// CF_CUSTOM_MATERIAL_CALLBACK and writes the combined friction/restitution
/// into the manifold point.
static bool CustomMaterialCombinerCallback(btManifoldPoint& cp,	const btCollisionObject* colObj0,int partId0,int index0,const btCollisionObject* colObj1,int partId1,int index1)
{
	// Start from each object's own material.
	float friction0    = colObj0->getFriction();
	float restitution0 = colObj0->getRestitution();
	float friction1    = colObj1->getFriction();
	float restitution1 = colObj1->getRestitution();

	const bool custom0 = (colObj0->getCollisionFlags() & btCollisionObject::CF_CUSTOM_MATERIAL_CALLBACK) != 0;
	const bool custom1 = (colObj1->getCollisionFlags() & btCollisionObject::CF_CUSTOM_MATERIAL_CALLBACK) != 0;

	if (custom0)
	{
		// Object 0: fixed full friction, no bounce.
		friction0    = 1.0f;
		restitution0 = 0.f;
	}
	if (custom1)
	{
		// Object 1: odd index1 values get friction 1, even values are frictionless.
		friction1    = (index1 & 1) ? 1.0f : 0.f;
		restitution1 = 0.f;
	}

	cp.m_combinedFriction    = calculateCombinedFriction(friction0,friction1);
	cp.m_combinedRestitution = calculateCombinedRestitution(restitution0,restitution1);

	// Bullet currently ignores this value; true signals that friction was computed.
	return true;
}

extern ContactAddedCallback gContactAddedCallback;

// Dimensions of the procedural wave grid written by Terrain::setVertexPositions().
static const int NUM_VERTS_X = 30;
static const int NUM_VERTS_Y = 30;
static const int totalVerts  = NUM_VERTS_X * NUM_VERTS_Y;

void Terrain::setVertexPositions(float waveheight, float offset)
{
	int i;
	int j;

	for ( i=0;i<NUM_VERTS_X;i++)
	{
		for (j=0;j<NUM_VERTS_Y;j++)
		{
			gVertices[i+j*NUM_VERTS_X].setValue((i-NUM_VERTS_X*0.5f)*TRIANGLE_SIZE,
				//0.f,
				waveheight*sinf((float)i+offset)*cosf((float)j+offset),
				(j-NUM_VERTS_Y*0.5f)*TRIANGLE_SIZE);
		}
	}
}

void Terrain::InitPhysics()
{
	#define TRISIZE 10.f

     gContactAddedCallback = CustomMaterialCombinerCallback;

#define USE_TRIMESH_SHAPE 1
#ifdef USE_TRIMESH_SHAPE

	int vertStride = sizeof(btVector3);
	int indexStride = 3*sizeof(int);

	
	const int totalTriangles = 2*(NUM_VERTS_X-1)*(NUM_VERTS_Y-1);

	gVertices = new btVector3[GetNumVertices()];
	gIndices = new int[GetNumIndices()];

	std::vector<TerrainVertex> vertices = GetVertices();
	for( UINT i = 0; i < GetNumVertices(); i ++ )
	{
		if( _isnan( vertices[i].pos.y ) )
			vertices[i].pos.y = 0.0f;

		gVertices[i].setValue( vertices[i].pos.x, vertices[i].pos.y,
			vertices[i].pos.z );
	}

	std::vector<int> indices = GetIndices();
	for( UINT i = 0; i < GetNumIndices(); i ++ )
	{
		gIndices[i] = indices[i];
	}

	UINT verts = GetNumVertices();
	//setVertexPositions(waveheight,0.f);

	//int index=0;
	//for ( UINT i=0;i<NUM_VERTS_X-1;i++)
	//{
	//	for (int j=0;j<NUM_VERTS_Y-1;j++)
	//	{
	//		gIndices[index++] = j*NUM_VERTS_X+i;
	//		gIndices[index++] = j*NUM_VERTS_X+i+1;
	//		gIndices[index++] = (j+1)*NUM_VERTS_X+i+1;

	//		gIndices[index++] = j*NUM_VERTS_X+i;
	//		gIndices[index++] = (j+1)*NUM_VERTS_X+i+1;
	//		gIndices[index++] = (j+1)*NUM_VERTS_X+i;
	//	}
	//}

	m_indexVertexArrays = new btTriangleIndexVertexArray(GetNumIndices() / 3,
		gIndices,
		indexStride,
		GetNumVertices(),(btScalar*) &gVertices[0].x(),vertStride);

	bool useQuantizedAabbCompression = true;

//comment out the next line to read the BVH from disk (first run the demo once to create the BVH)

#ifdef SERIALIZE_TO_DISK


	btVector3 aabbMin(-1000,-1000,-1000),aabbMax(1000,1000,1000);
	
	trimeshShape  = new btBvhTriangleMeshShape(m_indexVertexArrays,useQuantizedAabbCompression,aabbMin,aabbMax);
	m_collisionShapes.push_back(trimeshShape);

	int maxSerializeBufferSize = 1024*1024*50;
	btDefaultSerializer*	serializer = new btDefaultSerializer(maxSerializeBufferSize);
	//serializer->setSerializationFlags(BT_SERIALIZE_NO_BVH);//	or BT_SERIALIZE_NO_TRIANGLEINFOMAP
	serializer->startSerialization();
	//registering a name is optional, it allows you to retrieve the shape by name
	//serializer->registerNameForPointer(trimeshShape,"mymesh");
#ifdef SERIALIZE_SHAPE
	trimeshShape->serializeSingleShape(serializer);
#else
	trimeshShape->serializeSingleBvh(serializer);
#endif
	serializer->finishSerialization();
	FILE* f2 = fopen("myShape.bullet","wb");
	fwrite(serializer->getBufferPointer(),serializer->getCurrentBufferSize(),1,f2);
	fclose(f2);

#else
	btBulletWorldImporter import(0);//don't store info into the world
	if (import.loadFile("myShape.bullet"))
	{
		int numBvh = import.getNumBvhs();
		if (numBvh)
		{
			btOptimizedBvh* bvh = import.getBvhByIndex(0);
			btVector3 aabbMin(-1000,-1000,-1000),aabbMax(1000,1000,1000);
	
			trimeshShape  = new btBvhTriangleMeshShape(m_indexVertexArrays,useQuantizedAabbCompression,aabbMin,aabbMax,false);
			trimeshShape->setOptimizedBvh(bvh);
			//trimeshShape  = new btBvhTriangleMeshShape(m_indexVertexArrays,useQuantizedAabbCompression,aabbMin,aabbMax);
			//trimeshShape->setOptimizedBvh(bvh);
	
		}
		int numShape = import.getNumCollisionShapes();
		if (numShape)
		{
			trimeshShape = (btBvhTriangleMeshShape*)import.getCollisionShapeByIndex(0);
			
			//if you know the name, you can also try to get the shape by name:
			const char* meshName = import.getNameForPointer(trimeshShape);
			if (meshName)
				trimeshShape = (btBvhTriangleMeshShape*)import.getCollisionShapeByName(meshName);
			
		}
	}


#endif

	btCollisionShape* groundShape = trimeshShape;
	
#else
	btCollisionShape* groundShape = new btBoxShape(btVector3(50,3,50));
	m_collisionShapes.push_back(groundShape);

#endif //USE_TRIMESH_SHAPE

	m_collisionConfiguration = new btDefaultCollisionConfiguration();

#ifdef USE_PARALLEL_DISPATCHER

#ifdef USE_WIN32_THREADING

	int maxNumOutstandingTasks = 4;//number of maximum outstanding tasks
	Win32ThreadSupport* threadSupport = new Win32ThreadSupport(Win32ThreadSupport::Win32ThreadConstructionInfo(
								"collision",
								processCollisionTask,
								createCollisionLocalStoreMemory,
								maxNumOutstandingTasks));
#else
///@todo show other platform threading
///Playstation 3 SPU (SPURS)  version is available through PS3 Devnet
///Libspe2 SPU support will be available soon
///pthreads version
///you can hook it up to your custom task scheduler by deriving from btThreadSupportInterface
#endif

	m_dispatcher = new	SpuGatheringCollisionDispatcher(threadSupport,maxNumOutstandingTasks,m_collisionConfiguration);
#else
	m_dispatcher = new	btCollisionDispatcher(m_collisionConfiguration);
#endif//USE_PARALLEL_DISPATCHER

#ifdef USE_PARALLEL_DISPATCHER
	m_dynamicsWorld->getDispatchInfo().m_enableSPU=true;
#endif //USE_PARALLEL_DISPATCHER
	
	float mass = 0.f;
	btTransform	startTransform;
	startTransform.setIdentity();
	startTransform.setOrigin(btVector3(0,-2,0));

	startTransform.setIdentity();
	staticBody = LocalCreateRigidBody(mass, startTransform,groundShape);

	staticBody->setCollisionFlags(staticBody->getCollisionFlags() | btCollisionObject::CF_KINEMATIC_OBJECT);//STATIC_OBJECT);

	//enable custom material callback
	staticBody->setCollisionFlags(staticBody->getCollisionFlags() | btCollisionObject::CF_CUSTOM_MATERIAL_CALLBACK);
}

void Terrain::setDirectionToSun(const D3DXVECTOR3& v)
{
	D3DXVECTOR4 temp(v.x, v.y, v.z, 0.0f);
	mfxDirToSunVar->SetFloatVector((float*)temp);
}

void Terrain::draw(const D3DXMATRIX& world)
{
	md3dDevice->IASetInputLayout(InputLayout::PosNormalTex);

	UINT stride = sizeof(TerrainVertex);
    UINT offset = 0;
    md3dDevice->IASetVertexBuffers(0, 1, &mVB, &stride, &offset);
	md3dDevice->IASetIndexBuffer(mIB, DXGI_FORMAT_R32_UINT, 0);

	D3DXMATRIX view = GetCamera().View();
	D3DXMATRIX proj = GetCamera().Proj();

	D3DXMATRIX WVP = world*view*proj;


	mfxWVPVar->SetMatrix((float*)&WVP);
	mfxWorldVar->SetMatrix((float*)&world);

	mfxLayer0Var->SetResource(mLayer0);
	mfxLayer1Var->SetResource(mLayer1);
	mfxLayer2Var->SetResource(mLayer2);
	mfxLayer3Var->SetResource(mLayer3);
	mfxLayer4Var->SetResource(mLayer4);
	mfxBlendMapVar->SetResource(mBlendMap);

    D3D10_TECHNIQUE_DESC techDesc;
    mTech->GetDesc( &techDesc );

    for(UINT i = 0; i < techDesc.Passes; ++i)
    {
        ID3D10EffectPass* pass = mTech->GetPassByIndex(i);
		pass->Apply(0);

		md3dDevice->DrawIndexed(mNumFaces*3, 0, 0);
	}	
}

void Terrain::loadHeightmap()
{
	// Reads one unsigned byte per vertex (NumRows*NumCols bytes, row-major, from
	// the start of HeightmapFilename) and maps each to byte*HeightScale + HeightOffset.
	const size_t sampleCount = (size_t)mInfo.NumRows * mInfo.NumCols;

	// Zero-filled, so samples that cannot be read (missing or short file) map
	// to HeightOffset. This tolerance is deliberate best-effort behaviour.
	std::vector<unsigned char> in( sampleCount, 0 );

	// The stream is closed by its destructor.
	std::ifstream inFile( mInfo.HeightmapFilename.c_str(), std::ios_base::binary );
	// &in[0] is only valid for a non-empty vector.
	if( inFile && !in.empty() )
	{
		inFile.read( (char*)&in[0], (std::streamsize)in.size() );
	}

	// Copy the array data into a float array, and scale and offset the heights.
	mHeightmap.resize( sampleCount );
	for( size_t i = 0; i < sampleCount; ++i )
	{
		mHeightmap[i] = (float)in[i] * mInfo.HeightScale + mInfo.HeightOffset;
	}
}

void Terrain::smooth()
{
	std::vector<float> dest( mHeightmap.size() );

	for(UINT i = 0; i < mInfo.NumRows; ++i)
	{
		for(UINT j = 0; j < mInfo.NumCols; ++j)
		{
			dest[i*mInfo.NumCols+j] = average(i,j);
		}
	}

	// Replace the old heightmap with the filtered one.
	mHeightmap = dest;
}

bool Terrain::inBounds(UINT i, UINT j)
{
	// True if (i, j) addresses a heightmap sample. Both indices are unsigned,
	// so only the upper bounds need checking; a wrapped "-1" fails here too.
	const bool rowInside = i < mInfo.NumRows;
	const bool colInside = j < mInfo.NumCols;
	return rowInside && colInside;
}

float Terrain::average(UINT i, UINT j)
{
	// Function computes the average height of the ij element.
	// It averages itself with its eight neighbor pixels.  Note
	// that if a pixel is missing neighbor, we just don't include it
	// in the average--that is, edge pixels don't have a neighbor pixel.
	//
	// ----------
	// | 1| 2| 3|
	// ----------
	// |4 |ij| 6|
	// ----------
	// | 7| 8| 9|
	// ----------

	// Clamp the lower bounds explicitly. Starting the loops at i-1 / j-1 with
	// unsigned indices wraps to UINT_MAX when i or j is 0, so "m <= i+1" fails
	// immediately, nothing is summed, and 0/0 = NaN is returned for the whole
	// first row and column.
	const UINT rowBegin = (i > 0) ? i - 1 : 0;
	const UINT colBegin = (j > 0) ? j - 1 : 0;
	const UINT rowEnd   = i + 1; // inclusive; out-of-range rows rejected by inBounds()
	const UINT colEnd   = j + 1; // inclusive; out-of-range cols rejected by inBounds()

	float avg = 0.0f;
	float num = 0.0f;

	for(UINT m = rowBegin; m <= rowEnd; ++m)
	{
		for(UINT n = colBegin; n <= colEnd; ++n)
		{
			if( inBounds(m,n) )
			{
				avg += mHeightmap[m*mInfo.NumCols + n];
				num += 1.0f;
			}
		}
	}

	// num >= 1 whenever (i, j) itself is in bounds; guard anyway against a divide by zero.
	return (num > 0.0f) ? avg / num : 0.0f;
}

void Terrain::buildVB()
{
	//std::vector<TerrainVertex> vertices(mNumVertices);

	float halfWidth = (mInfo.NumCols-1)*mInfo.CellSpacing*0.5f;
	float halfDepth = (mInfo.NumRows-1)*mInfo.CellSpacing*0.5f;

	float du = 1.0f / (mInfo.NumCols-1);
	float dv = 1.0f / (mInfo.NumRows-1);
	for(UINT i = 0; i < mInfo.NumRows; ++i)
	{
		float z = halfDepth - i*mInfo.CellSpacing;
		for(UINT j = 0; j < mInfo.NumCols; ++j)
		{
			TerrainVertex v;

			float x = -halfWidth + j*mInfo.CellSpacing;

			float y = mHeightmap[i*mInfo.NumCols+j];
			v.pos    = D3DXVECTOR3(x, y, z);
			v.normal = D3DXVECTOR3(0.0f, 1.0f, 0.0f);

			// Stretch texture over grid.
			v.texC.x = j*du;
			v.texC.y = i*dv;

			mVertices.push_back(v);
		}
	}
 
	// Estimate normals for interior nodes using central difference.
	float invTwoDX = 1.0f / (2.0f*mInfo.CellSpacing);
	float invTwoDZ = 1.0f / (2.0f*mInfo.CellSpacing);
	for(UINT i = 2; i < mInfo.NumRows-1; ++i)
	{
		for(UINT j = 2; j < mInfo.NumCols-1; ++j)
		{
			float t = mHeightmap[(i-1)*mInfo.NumCols + j];
			float b = mHeightmap[(i+1)*mInfo.NumCols + j];
			float l = mHeightmap[i*mInfo.NumCols + j - 1];
			float r = mHeightmap[i*mInfo.NumCols + j + 1];

			D3DXVECTOR3 tanZ(0.0f, (t-b)*invTwoDZ, 1.0f);
			D3DXVECTOR3 tanX(1.0f, (r-l)*invTwoDX, 0.0f);

			D3DXVECTOR3 N;
			D3DXVec3Cross(&N, &tanZ, &tanX);
			D3DXVec3Normalize(&N, &N);

			mVertices[i*mInfo.NumCols+j].normal = N;
		}
	}

    D3D10_BUFFER_DESC vbd;
    vbd.Usage = D3D10_USAGE_IMMUTABLE;
    vbd.ByteWidth = sizeof(TerrainVertex) * mNumVertices;
    vbd.BindFlags = D3D10_BIND_VERTEX_BUFFER;
    vbd.CPUAccessFlags = 0;
    vbd.MiscFlags = 0;
	D3D10_SUBRESOURCE_DATA vinitData;
    vinitData.pSysMem = &mVertices[0];
    HR(md3dDevice->CreateBuffer(&vbd, &vinitData, &mVB));
}

void Terrain::buildIB()
{
	//std::vector<DWORD> indices(mNumFaces*3); // 3 indices per face

	// Iterate over each quad and compute indices.
	int k = 0;
	for(UINT i = 0; i < mInfo.NumRows-1; ++i)
	{
		for(UINT j = 0; j < mInfo.NumCols-1; ++j)
		{
			mIndices.push_back( i * mInfo.NumCols + j );
			mIndices.push_back( i * mInfo.NumCols + j + 1 );
			mIndices.push_back( ( i + 1 ) * mInfo.NumCols + j );

			mIndices.push_back( ( i + 1 ) * mInfo.NumCols + j );
			mIndices.push_back( i * mInfo.NumCols + j + 1 );
			mIndices.push_back( ( i + 1 ) * mInfo.NumCols + j + 1 );

			k += 6; // next quad
		}
	}

	D3D10_BUFFER_DESC ibd;
    ibd.Usage = D3D10_USAGE_IMMUTABLE;
    ibd.ByteWidth = sizeof(DWORD) * mNumFaces*3;
    ibd.BindFlags = D3D10_BIND_INDEX_BUFFER;
    ibd.CPUAccessFlags = 0;
    ibd.MiscFlags = 0;
    D3D10_SUBRESOURCE_DATA iinitData;
    iinitData.pSysMem = &mIndices[0];
    HR(md3dDevice->CreateBuffer(&ibd, &iinitData, &mIB));
}

std::vector<TerrainVertex> Terrain::GetVertices()
{
	// Returns a copy of the CPU-side vertex array built by buildVB().
	return std::vector<TerrainVertex>( mVertices.begin(), mVertices.end() );
}

std::vector<int> Terrain::GetIndices()
{
	// Returns a copy of the CPU-side index array built by buildIB().
	return std::vector<int>( mIndices.begin(), mIndices.end() );
}

UINT Terrain::GetNumVertices()
{
	// NumRows * NumCols, as computed in init().
	const UINT vertexCount = mNumVertices;
	return vertexCount;
}

int Terrain::GetNumIndices()
{
	// Three indices per triangle.
	return 3 * mNumFaces;
}

btRigidBody* Terrain::LocalCreateRigidBody(float mass, const btTransform& startTransform,btCollisionShape* shape)
{
	// Creates a rigid body for `shape` at `startTransform`, adds it to the world
	// passed to init(), and returns it. mass == 0 yields a static body.
	btAssert((!shape || shape->getShapeType() != INVALID_SHAPE_PROXYTYPE));

	// Only a body with mass is dynamic and needs an inertia tensor; a static
	// body keeps zero inertia.
	btVector3 localInertia(0,0,0);
	const bool isDynamic = (mass != 0.f);
	if (isDynamic)
	{
		shape->calculateLocalInertia(mass,localInertia);
	}

	btRigidBody* body = 0;

	//using motionstate is recommended, it provides interpolation capabilities, and only synchronizes 'active' objects
#define USE_MOTIONSTATE 1
#ifdef USE_MOTIONSTATE
	btDefaultMotionState* motionState = new btDefaultMotionState(startTransform);
	body = new btRigidBody(btRigidBody::btRigidBodyConstructionInfo(mass,motionState,shape,localInertia));
	body->setContactProcessingThreshold(mDefaultContactProcessingThreshold);
#else
	body = new btRigidBody(mass,0,shape,localInertia);
	body->setWorldTransform(startTransform);
#endif//

	mDynamicsWorld->addRigidBody(body);
	return body;
}